import unittest
from io import StringIO, TextIOBase
from typing import List, Dict


class HW_OJ_20210929(unittest.TestCase):
    """Solutions to two programming exercises together with their unit tests.

    The solutions are exposed as static methods (``max_link_length`` and
    ``sub_interval_area``); the ``test_*`` methods check them against known
    input/output pairs.
    """

    @staticmethod
    def get_deep(nodes: List[Dict], _idx: int) -> int:
        """Return the height, counted in nodes, of the subtree rooted at ``_idx``.

        :param nodes: node records; each record's ``'children'`` entry is a
            list of indices into ``nodes``.
        :param _idx: index of the subtree root.
        :return: 1 for a node without children, otherwise one more than the
            height of its tallest child subtree.
        """
        child_ids = nodes[_idx]['children']
        # ``default=0`` makes a childless node evaluate to a height of 1.
        return 1 + max(
            (HW_OJ_20210929.get_deep(nodes, child) for child in child_ids),
            default=0,
        )

    @staticmethod
    def max_link_length(_input: TextIOBase) -> int:
        """Return the length of the longest chain that passes through the target.

        Input layout (whitespace separated, blank lines ignored):

        * line 1: the number of link records to read;
        * line 2: the name of the target node;
        * following lines: ``device start end`` records, where ``start``
          names the parent of ``device`` and ``end`` names one of its
          children; the literal ``null`` means "absent".

        A parent declared explicitly through ``start`` always takes precedence
        over one inferred from another record's ``end`` field, regardless of
        the order of the records.

        :param _input: text stream positioned at the first line.
        :return: the number of ancestors of the target plus the height (in
            nodes) of the subtree rooted at the target.
        :raises KeyError: if the target node is not named by any record read.
        :raises ValueError: if a record does not have exactly three fields, or
            if the parent links above the target form a cycle.
        """
        # Parse the input and build the graph.
        edge_size = int(_input.readline().strip())
        target_node = _input.readline().strip()

        nodes_index = {}
        links = []
        for raw_line in _input:
            fields = raw_line.split()  # tolerate tabs / repeated spaces
            if not fields:
                continue
            device, start, end = fields
            links.append((device, start, end))
            for name in (device, start, end):
                if name != 'null' and name not in nodes_index:
                    # Indices are handed out in order of first appearance.
                    nodes_index[name] = len(nodes_index)
            if len(links) >= edge_size:
                break

        nodes = [None] * len(nodes_index)
        for name, idx in nodes_index.items():
            nodes[idx] = {'name': name, 'parent': None, 'children': []}

        for device, start, end in links:
            device_id = nodes_index[device]
            if start != 'null':
                nodes[device_id]['parent'] = nodes_index[start]
            if end != 'null':
                child = nodes[nodes_index[end]]
                nodes[device_id]['children'].append(nodes_index[end])
                # Compare with None explicitly: index 0 is a valid parent and
                # must not be mistaken for "no parent recorded yet".
                if child['parent'] is None:
                    child['parent'] = device_id

        target = nodes_index[target_node]
        deep_of_target_tree = HW_OJ_20210929.get_deep(nodes, target)

        # Count the ancestors of the target by walking the parent links.
        parent_deep = 0
        visited = {target}
        while nodes[target]['parent'] is not None:
            target = nodes[target]['parent']
            if target in visited:
                # Without this guard a malformed input would loop forever.
                raise ValueError('cycle detected in parent links')
            visited.add(target)
            parent_deep += 1

        return parent_deep + deep_of_target_tree

    def test_2nd_question(self):
        """Check ``max_link_length`` against known input/output pairs."""
        input_and_output = [
            ("""8
d3
d1 null d2
d2 d1 d3
d2 d1 d5
d2 d1 d4
d3 d2 null
d4 d2 null
d5 d2 d6
d8 d5 null""", 3),
            ("""8
        d2
        d1 null d2
        d2 d1 d3
        d2 d1 d5
        d2 d1 d4
        d3 d2 null
        d4 d2 null
        d5 d2 d6
        d8 d5 null""", 4),
            ("""8
        d5
        d1 null d2
        d2 d1 d3
        d2 d1 d5
        d2 d1 d4
        d3 d2 null
        d4 d2 null
        d5 d2 d6
        d8 d5 null""", 4),
            ("""4
device1
device1 device2 device3
device2 null device1
device3 device1 device4
device4 device3 null""", 4),
            # Regression: node "a" receives index 0; the parent that "b"
            # declares explicitly must survive the later record naming "b"
            # as a child of "c".
            ("""4
b
a null null
b a null
c d b
d null c""", 2),
        ]

        for input0, output0 in input_and_output:
            with self.subTest(target=input0.split()[1]):
                self.assertEqual(
                    HW_OJ_20210929.max_link_length(StringIO(input0)), output0
                )

    @staticmethod
    def sub_interval_area(arr: List[int]) -> int:
        """Evaluate the closed-form "sum of sub-interval areas" for ``arr``.

        :param arr: values in non-strictly increasing order
            (``arr[i] <= arr[j]`` whenever ``i < j``); the formula relies on
            this ordering, which is not checked here.
        :return: ``(n + 2) * sum(i * a_i) - (n + 2) * (n - 1) / 2 * sum(a_i)``
            with ``n = len(arr)`` and ``i`` starting at 0.
        """
        n = len(arr)
        weighted_sum = sum(i * a_i for i, a_i in enumerate(arr))
        plain_sum = sum(arr)
        # (n + 2) and (n - 1) differ by 3, so exactly one of them is even and
        # the floor division by 2 below never discards a remainder.
        return (n + 2) * weighted_sum - (n + 2) * (n - 1) // 2 * plain_sum

    def test_3rd_question(self):
        """Check ``sub_interval_area`` against known input/output pairs."""
        lines = StringIO("""30 20 10 5 3 1
        10 100 1000
        1 4 2000""")
        expected_output = [804, 4950, 9995]
        for line, output in zip(lines, expected_output):
            # The formula requires ascending order.
            arr = sorted(map(int, line.split()))
            with self.subTest(arr=arr):
                self.assertEqual(HW_OJ_20210929.sub_interval_area(arr), output)


# Run the test cases defined in this module when the file is executed as a
# script; nothing happens here when the module is merely imported.
if __name__ == '__main__':
    unittest.main()

